{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Useful util method to extract a specific token's representation from the last hidden states of a transformer model. \n",
    "\n",
    "Typically used for tasks where a singular representation of the entire sequence is required, such as sentence embedding."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def last_token_pool(last_hidden_states, attention_mask):\n",
    "        left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])\n",
    "        if left_padding:\n",
    "            return last_hidden_states[:, -1]\n",
    "        else:\n",
    "            sequence_lengths = attention_mask.sum(dim=1) - 1\n",
    "            batch_size = last_hidden_states.shape[0]\n",
    "            return last_hidden_states[\n",
    "                torch.arange(batch_size,\n",
    "                             device=last_hidden_states.device),\n",
    "                             sequence_lengths\n",
    "                             ]\n",
    "\n",
    "# usage of `last_token_pool()` method\n",
    "\n",
    "embeddings = last_token_pool(last_hidden_state, attention_mask)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "📌 The function `last_token_pool` serves this purpose. It intelligently selects the last meaningful token from each sequence. This selection is based on the `attention_mask` to ensure padding tokens are not mistakenly chosen. Padding tokens do not contain useful information about the sequence content.\n",
    "\n",
    "📌 The method handles different padding strategies:\n",
    "   - If sequences are left-padded (padding at the beginning), the last token (rightmost) in the sequence is a meaningful token and is chosen.\n",
    "   - If sequences are right-padded or have variable lengths, it calculates the length of each sequence and selects the last meaningful token accordingly.\n",
    "\n",
    "📌 In transformer models, an attention mask is a binary matrix that indicates which tokens in the sequence are padding tokens (`0` for padding, `1` for real tokens). The shape of `attention_mask` is typically `[batch_size, sequence_length]`.\n",
    "\n",
    "📌 **Indexing `[:, -1]`**: This part selects the last column of the `attention_mask`. The `-1` index refers to the last element in a dimension, so this slice operation retrieves the padding status of the last token for each sequence in the batch.\n",
    "\n",
    " `attention_mask.shape[0]` gives us the batch size, i.e., the total number of sequences."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Full Explanation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### First some basics\n",
    "\n",
    "📌 `batch_size`: In the context of neural networks, particularly those processing sequences like transformers, `batch_size` refers to the number of sequences processed together in one forward/backward pass. It is a crucial parameter in training, influencing memory usage and optimization dynamics. In the code, `batch_size` is derived from the shape of `last_hidden_states`, representing the number of sequences in the current batch.\n",
    "\n",
    "📌 `hidden state`: In transformer models, the hidden state refers to the output of the transformer layers. For each token in the input sequence, the transformer generates a vector representation (hidden state) capturing the contextual information. These representations are essential for downstream tasks, encoding both the intrinsic meaning of the token and its relation to other tokens in the sequence.\n",
    "\n",
    "📌 `sequence_lengths`: This represents the actual length of each sequence in the batch, excluding padding tokens. In sequences processed by transformers, padding is often used to equalize the lengths of different sequences in a batch. `sequence_lengths` is calculated to identify the position of the last real (non-padding) token in each sequence.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "----------\n",
    "\n",
    "📌 The `last_token_pool` method first checks the padding direction used in the input sequences. The `left_padding` boolean variable is set to `True` if the last token of every sequence in the batch is a padding token. This is determined by checking if the sum of the last column in the `attention_mask` equals the batch size. If true, this implies all sequences are right-padded (common in English language processing tasks).\n",
    "\n",
    "📌 The line `left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])` is used to determine the padding direction in tokenized sequences. Let's break this down:\n",
    "\n",
    "- **Attention Mask Structure**: In transformer models, an attention mask is a binary matrix that indicates which tokens in the sequence are padding tokens (`0` for padding, `1` for real tokens). The shape of `attention_mask` is typically `[batch_size, sequence_length]`.\n",
    "\n",
    "- **Indexing `[:, -1]`**: This part selects the last column of the `attention_mask`. The `-1` index refers to the last element in a dimension, so this slice operation retrieves the padding status of the last token for each sequence in the batch.\n",
    "\n",
    "- **`.sum()` Operation**: By summing the values of the last column, we are effectively counting how many sequences have a non-padding token as their last element. Since non-padding tokens are represented by `1`, the sum will equal the number of sequences that end with a real token.\n",
    "\n",
    "- **Comparing to `attention_mask.shape[0]`**: `attention_mask.shape[0]` gives us the batch size, i.e., the total number of sequences. By comparing the sum to the batch size, we check whether all sequences in the batch have a non-padding token as their last element.\n",
    "\n",
    "📌 If `left_padding` is `True`, this implies that all sequences in the batch end with a real token. Meaning the sum of the values in the last column of the attention mask equals the number of sequences in the batch. i.e. Sequences are left-padded. This means that padding tokens (if present) are at the start of the sequences, and the meaningful content (real tokens) follows the padding.\n",
    "\n",
    "📌 Conversely, if `left_padding` is `False`, not all sequences end with a real token, implying that there is at least one sequence that is left-padded (the padding tokens are at the end). \n",
    "\n",
    "📌 This check is critical for correctly extracting the relevant token's representation from the sequences, especially when dealing with variable-length inputs in a batch. The method needs to know where the meaningful tokens end to avoid including padding representations in downstream tasks.\n",
    "\n",
    "-------\n",
    "\n",
    "📌 If `left_padding` is `True`, the method retrieves the last token of each sequence, which is the last non-padding token. It does this by selecting the last column (`-1`) of `last_hidden_states`.\n",
    "\n",
    "------"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "📌 If `left_padding` is `False`, indicating sequences are right-padded, it calculates the actual length of each sequence in the batch by summing the `attention_mask` along each row and subtracting 1. This gives the index of the last non-padding token in each sequence.\n",
    "\n",
    "\n",
    "```py\n",
    "sequence_lengths = attention_mask.sum(dim=1) - 1\n",
    "batch_size = last_hidden_states.shape[0]\n",
    "return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]\n",
    "# The above line effectively gathers the representations of these last non-padding tokens.\n",
    "\n",
    "```\n",
    "\n",
    "📌 The logic of `sequence_lengths = attention_mask.sum(dim=1) - 1`:\n",
    "\n",
    "- `attention_mask.sum(dim=1)`: This sums the attention mask along the sequence length dimension (`dim=1`). Since the attention mask contains `1` for real tokens and `0` for padding tokens, the sum for each sequence gives the number of real tokens.\n",
    "- Subtracting `1`: The aim is to get the index of the last real token. Indexing in Python is zero-based, so to get the index of the last real token, we subtract `1` from the total count of real tokens.\n",
    "\n",
    "📌 `torch.arange(batch_size, device=last_hidden_states.device)`: This creates a tensor of indices from `0` to `batch_size-1`, matching the batch dimension of `last_hidden_states`.  Each index corresponds to a sequence in the batch. \n",
    "\n",
    "`device=last_hidden_states.device` ensures that the tensor is on the same device as `last_hidden_states` for compatibility.\n",
    "\n",
    "- `sequence_lengths` contains the index of the last real token for each sequence.\n",
    "\n",
    "- The indexing operation `[torch.arange(batch_size), sequence_lengths]`: This is advanced indexing in PyTorch. For each sequence in the batch (identified by `torch.arange(batch_size)`), it selects the hidden state corresponding to the last real token (identified by `sequence_lengths`). Essentially, it’s picking the hidden state of the last real token for each sequence in the batch, resulting in a tensor where each row corresponds to the last token's hidden state of each sequence.\n",
    "\n",
    "📌 When using these two tensors for indexing `last_hidden_states`, the operation selects a specific element from each sequence. The `torch.arange(batch_size)` part ensures that we are considering each sequence in the batch, and `sequence_lengths` picks the hidden state of the last real token in each sequence.\n",
    "\n",
    "- This form of indexing, known as advanced indexing in PyTorch, allows for selecting a different element (in this case, the hidden state of the last real token) from each row (or sequence) of the tensor `last_hidden_states`.\n",
    "\n",
    "- The result is a new tensor where each row corresponds to the hidden state of the last real token of each sequence, effectively condensing the essential information of each sequence into a single vector representation.\n",
    "\n",
    "-------\n",
    "\n",
    "📌 This process is crucial for tasks like semantic search or sentence classification, where the model must generate a single vector representation per input sequence. The selected token's hidden state is often a good summary of the sequence's overall meaning or content."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "usage of `last_token_pool()` method - \n",
    "\n",
    "`embeddings = self.last_token_pool(last_hidden_state, attention_mask)`\n",
    "\n",
    "WHY and How the above calculation of `embeddings` makes sense ?\n",
    "\n",
    "📌 Here, the purpose is to extract a representative embedding for the whole sequence, often crucial in tasks like semantic search or sentence classification.\n",
    "\n",
    "📌 The `last_hidden_state`, outputted by the transformer model, contains a rich representation for each token in the sequence. However, for many applications, a single vector representing the entire sequence is needed.\n",
    "\n",
    "📌 The approach used here is to select a specific token's representation from the `last_hidden_state`. The choice of token is pivotal as it should ideally encapsulate the context and content of the entire sequence."
   ]
  },
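  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "📌 A hedged end-to-end sketch of the pooling step: the function is repeated so the snippet runs on its own, and random tensors stand in for real model outputs (an assumption). Pooling a right-padded batch and its left-padded counterpart should return the same vectors.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "def last_token_pool(last_hidden_states, attention_mask):\n",
    "    left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])\n",
    "    if left_padding:\n",
    "        return last_hidden_states[:, -1]\n",
    "    sequence_lengths = attention_mask.sum(dim=1) - 1\n",
    "    batch_size = last_hidden_states.shape[0]\n",
    "    return last_hidden_states[\n",
    "        torch.arange(batch_size, device=last_hidden_states.device),\n",
    "        sequence_lengths]\n",
    "\n",
    "torch.manual_seed(0)\n",
    "lengths = [3, 2]             # real tokens per sequence (made up)\n",
    "seq_len, hidden_size = 4, 5  # made-up sizes\n",
    "real = [torch.randn(n, hidden_size) for n in lengths]\n",
    "\n",
    "right_states = torch.zeros(2, seq_len, hidden_size)\n",
    "left_states = torch.zeros(2, seq_len, hidden_size)\n",
    "right_mask = torch.zeros(2, seq_len, dtype=torch.long)\n",
    "left_mask = torch.zeros(2, seq_len, dtype=torch.long)\n",
    "for i, n in enumerate(lengths):\n",
    "    right_states[i, :n] = real[i]           # real tokens first, padding after\n",
    "    right_mask[i, :n] = 1\n",
    "    left_states[i, seq_len - n:] = real[i]  # padding first, real tokens after\n",
    "    left_mask[i, seq_len - n:] = 1\n",
    "\n",
    "pooled_right = last_token_pool(right_states, right_mask)\n",
    "pooled_left = last_token_pool(left_states, left_mask)\n",
    "print(torch.equal(pooled_right, pooled_left))  # True\n",
    "```\n",
    "\n",
    "The sketch only exercises the pooling step. Whether a real model produces identical hidden states for both padding sides depends on how that model handles padding, which is not covered here."
   ]
  },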
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Simple demonstration of the advanced indexing concept using PyTorch:\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "# Sample hidden states tensor, simulating the output of a transformer layer\n",
    "# Assume each row is a sequence, and each column is a token's hidden state\n",
    "# Using sequential numbers for simplicity\n",
    "hidden_states = torch.arange(1, 21).view(4, 5)\n",
    "print(\"Hidden States Tensor:\\n\", hidden_states)\n",
    "\n",
    "# Assuming the actual lengths of sequences (excluding padding) are:\n",
    "sequence_lengths = torch.tensor([4, 3, 2, 5])  # lengths of each sequence\n",
    "# i.e. the first sequence has a length of 4, the 2nd one sequence has\n",
    "# length of 3 and so on\n",
    "\n",
    "# Batch size is the number of sequences, which is the number of rows in hidden_states\n",
    "batch_size = hidden_states.shape[0]\n",
    "\n",
    "# Using advanced indexing to select the last real token's hidden state for each sequence\n",
    "selected_hidden_states = hidden_states[torch.arange(batch_size), sequence_lengths - 1]\n",
    "\n",
    "print(\"\\nSelected Hidden States (Last Real Token per Sequence):\\n\", selected_hidden_states)\n",
    "\n",
    "```\n",
    "\n",
    "Output:\n",
    "\n",
    "```\n",
    "Hidden States Tensor:\n",
    " tensor([[ 1,  2,  3,  4,  5],\n",
    "        [ 6,  7,  8,  9, 10],\n",
    "        [11, 12, 13, 14, 15],\n",
    "        [16, 17, 18, 19, 20]])\n",
    "\n",
    "Selected Hidden States (Last Real Token per Sequence):\n",
    " tensor([ 4,  8, 12, 20])\n",
    "```\n",
    "\n",
    "📌 Explanation:\n",
    "\n",
    "- `hidden_states` represents a batch of sequences with their respective hidden states. Each row corresponds to a different sequence.\n",
    "\n",
    "- `sequence_lengths` contains the length of each sequence without padding.\n",
    "\n",
    "- The indexing `hidden_states[torch.arange(batch_size), sequence_lengths - 1]` selects the last real token's hidden state in each sequence. The `-1` accounts for zero-based indexing. \n",
    "\n",
    "- `torch.arange(batch_size)` generates indices for each sequence, and `sequence_lengths - 1` points to the last real token in each sequence.\n",
    "\n",
    "- The result `selected_hidden_states` contains the hidden state of the last real token for each sequence, aligned with the input order."
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
